import pickle
import matplotlib.pyplot as plt
import numpy as np

path = 'D:/1_Documents/Courses/04 Graduate I/00Project/xcb_Deeplearning/Books/fishbook/【源代码】深度学习入门：基于Python的理论与实现_20240716/dataset/mnist.pkl'  # path='/ro
# 读取 mnist.pkl
with open(path, "rb") as f:
    dataset = pickle.load(f)

# 打印结构信息
print("数据集的 key:", dataset.keys())
print("训练图像 shape:", dataset['train_img'].shape)
print("训练标签 shape:", dataset['train_label'].shape)
print("测试图像 shape:", dataset['test_img'].shape)
print("测试标签 shape:", dataset['test_label'].shape)

# 取一个样本出来看看
idx = 0  # 可以改成别的，比如 123
img = dataset['train_img'][idx]
label = dataset['train_label'][idx]

# 因为保存时是 (784,) 一维，需要 reshape 回 (28, 28)
img = img.reshape(28, 28)

print("样本标签:", label)

# 可视化显示图片
plt.imshow(img, cmap="gray")
plt.title(f"Label: {label}")
plt.axis("off")
plt.show()
